Test: animating Ensemble Kalman Filter (EnKF) states

Imports

import pandas as pd
import plotly.express as px
import sys
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sys.path.append('../dust/Projects/ABM_DA/stationsim/')
# from stationsim_gcs_model import Model
# from ensemble_kalman_filter import EnsembleKalmanFilter, EnsembleKalmanFilterType, ActiveAgentNormaliser
from stationsim_gcs_model import Model
from ensemble_kalman_filter import EnsembleKalmanFilter, EnsembleKalmanFilterType
%matplotlib inline

Functions

def __make_exit_observation_operator(population_size):
    """Return the DUAL_EXIT observation operator H with shape (2n, 3n), n = population_size.

    The left 2n x 2n block is the identity and the right 2n x n block is all
    zeros. H @ state therefore keeps the first 2n state entries and discards
    the trailing n. Those trailing entries are presumably the per-agent exit
    parameters; confirm against the EnKF state layout.
    """
    n_observed = 2 * population_size
    # eye(N, M) puts ones on the main diagonal of an N x M matrix, which is
    # exactly identity(N) with (M - N) zero columns appended on the right.
    return np.eye(n_observed, n_observed + population_size)


def __make_observation_operator(population_size, mode):
    """Return the observation operator H matching the filter ``mode``.

    STATE mode observes the full 2 * population_size state vector, so H is
    the square identity. DUAL_EXIT mode delegates to
    ``__make_exit_observation_operator``, whose H appends population_size
    zero columns for the extra state entries.

    Raises:
        ValueError: if ``mode`` is neither STATE nor DUAL_EXIT.
    """
    if mode == EnsembleKalmanFilterType.STATE:
        return np.eye(2 * population_size)
    if mode == EnsembleKalmanFilterType.DUAL_EXIT:
        return __make_exit_observation_operator(population_size)
    raise ValueError(f'Unexpected filter mode: {mode}')

        
def __make_state_vector_length(population_size, mode):
    """Return the length of the filter state vector for ``mode``.

    STATE mode stores 2 entries per agent. DUAL_EXIT mode stores 3 entries
    per agent: the 2 observed ones plus 1 extra.

    Raises:
        ValueError: if ``mode`` is neither STATE nor DUAL_EXIT.
    """
    if mode == EnsembleKalmanFilterType.STATE:
        entries_per_agent = 2
    elif mode == EnsembleKalmanFilterType.DUAL_EXIT:
        entries_per_agent = 3
    else:
        raise ValueError(f'Unexpected filter mode: {mode}')
    return entries_per_agent * population_size

        
def run_enkf(filter_params, model_params, normaliser, station, pickle_path, filter_id):
    """Run an Ensemble Kalman Filter to completion and pickle it to disk.

    Args:
        filter_params: Filter configuration dict. A copy is passed to
            ``EnsembleKalmanFilter`` with ``'error_normalisation'`` set to
            ``normaliser``. The caller's dict is not modified.
        model_params: Model configuration dict. A copy is passed on with
            ``'station'`` set to ``station``. The caller's dict is not modified.
        normaliser: Error-normalisation enum member (its ``.name`` is used in
            the file name), or None. None is recorded as ``'default'``.
        station: Station name for the model, or None. None is recorded as
            ``'toy'`` in the file name.
        pickle_path: Prefix prepended verbatim to the file name. Include the
            trailing path separator.
        filter_id: Identifier appended to the file name.

    Returns:
        The path of the written ``.pkl`` file. Earlier versions returned None.
    """
    # The notebook's import cell does not import pickle; without this the
    # dump below raised NameError after the (long) filter run had finished.
    import pickle

    # Work on copies so the injected keys do not leak into the caller's dicts.
    # Otherwise they would silently reach any later direct EnsembleKalmanFilter
    # construction that reuses the same dicts.
    filter_params = {**filter_params, 'error_normalisation': normaliser}
    model_params = {**model_params, 'station': station}

    enkf = EnsembleKalmanFilter(Model, filter_params, model_params,
                                filtering=True, benchmarking=True)

    while enkf.active:
        enkf.step()

    norm = normaliser.name if normaliser is not None else 'default'
    mt = station if station is not None else 'toy'
    out_path = pickle_path + f'filter_{mt}_{norm}_{filter_id}.pkl'

    with open(out_path, 'wb') as f:
        pickle.dump(enkf, f)
    return out_path


def get_rows(xs, ys, model_type, id_prefix, output):
    """Expand paired coordinate sequences into one row dict per agent.

    Each row is a shallow copy of ``output`` extended, in order, with
    ``model_type``, ``agent_id`` (``'<id_prefix>_agent_<i>'``), ``x`` and
    ``y``, where ``i`` is the position in ``xs``.

    ``ys`` is indexed with the same positions. It must be at least as long as
    ``xs``; any extra entries are ignored. ``output`` itself is never modified.
    """
    def build_row(index):
        # Fresh copy per agent so rows never share state with each other
        # or with the caller's template dict.
        row = output.copy()
        row['model_type'] = model_type
        row['agent_id'] = f'{id_prefix}_agent_{index}'
        row['x'] = xs[index]
        row['y'] = ys[index]
        return row

    return [build_row(index) for index in range(len(xs))]

Constants

# Extent of the station model in model coordinates; used below as the plot
# axis ranges and to size the figure.
model_width, model_height = 740, 700

Run EnKF

# Experiment configuration for a single EnKF run on the Grand_Central station.
# ensemble_size -> 'ensemble_size'; pop_size -> model 'pop_total' and filter
# 'population_size'; its -> 'max_iterations'.
# assimilation_period is passed as 'assimilation_period' (presumably the number
# of model steps between assimilations -- confirm against EnsembleKalmanFilter).
ensemble_size = 20
pop_size = 20
assimilation_period = 20
obs_noise_std = 1.0
mode = EnsembleKalmanFilterType.STATE
its = 200

model_params = {'pop_total': pop_size,
                'do_print': False}

# Set up filter parameters.
# H and the state length follow the filter `mode`. The observed data always
# uses the STATE layout (2 * pop_size entries) regardless of `mode`, which is
# why the data vector length is computed with data_mode rather than mode.
observation_operator = __make_observation_operator(pop_size, mode)
state_vec_length = __make_state_vector_length(pop_size, mode)
data_mode = EnsembleKalmanFilterType.STATE
data_vec_length = __make_state_vector_length(pop_size, data_mode)

# NOTE(review): the recorded output of this cell shows EnKF warning that
# 'vanilla_ensemble_size' is an unexpected attribute -- confirm whether this
# key is still needed alongside 'run_vanilla'.
# 'R_vector' holds obs_noise_std for every element of the data vector;
# NOTE(review): confirm whether EnKF treats these as std devs or variances.
filter_params = {'max_iterations': its,
                 'assimilation_period': assimilation_period,
                 'ensemble_size': ensemble_size,
                 'population_size': pop_size,
                 'vanilla_ensemble_size': ensemble_size,
                 'state_vector_length': state_vec_length,
                 'data_vector_length': data_vec_length,
                 'mode': mode,
                 'H': observation_operator,
                 'R_vector': obs_noise_std * np.ones(data_vec_length),
                 'keep_results': True,
                 'run_vanilla': True,
                 'vis': False}
# Optional error normalisation; re-enabling this line also requires the
# commented-out ActiveAgentNormaliser import in the import cell.
# filter_params['error_normalisation'] = ActiveAgentNormaliser.BASE
# Model parameter: run the model on the 'Grand_Central' station.
model_params['station'] = 'Grand_Central'

enkf = EnsembleKalmanFilter(Model, filter_params, model_params,
                            filtering=True, benchmarking=True)

# Advance the filter until it reports it is no longer active.
while enkf.active:
    enkf.step()
../dust/Projects/ABM_DA/stationsim/ensemble_kalman_filter.py:215: RuntimeWarning: EnKF received unexpected attribute (vanilla_ensemble_size).
  warns.warn(w, RuntimeWarning)
Running Ensemble Kalman Filter...
max_iterations: 200
ensemble_size:  20
assimilation_period:    20
pop_size:   20
filter_type:    EnsembleKalmanFilterType.STATE
inclusion_type: None
ensemble_errors:    False

Process results

# Flatten every stored filter snapshot into a long ("tidy") table with one row
# per (time, model_type, agent). Each spec is
#   (key in the result dict, model_type label, agent_id prefix)
# and the spec order fixes the row order of the resulting DataFrame.
summary_specs = [(name, name, name)
                 for name in ('observation', 'ground_truth', 'baseline',
                              'prior', 'posterior')]
member_specs = (
    [(f'prior_{j}', 'prior_ensemble_member', f'prior_ensemble_member_{j}')
     for j in range(enkf.ensemble_size)]
    + [(f'posterior_{j}', 'posterior_ensemble_member', f'posterior_ensemble_member_{j}')
       for j in range(enkf.ensemble_size)]
)
endpoint_specs = [(name, name, name) for name in ('destination', 'origin')]
row_specs = summary_specs + member_specs + endpoint_specs

results = list()
for result in enkf.results:
    # Every row from this snapshot shares the same timestamp.
    frame_fields = {'time': result['time']}
    for key, label, prefix in row_specs:
        xs, ys = enkf.separate_coords(result[key])
        results.extend(get_rows(xs, ys, label, prefix, frame_fields))

results = pd.DataFrame(results)
results.head()
time model_type agent_id x y
0 20 observation observation_agent_0 179.201482 679.897110
1 20 observation observation_agent_1 580.208471 691.454196
2 20 observation observation_agent_2 0.555967 -2.110513
3 20 observation observation_agent_3 0.550657 -0.114622
4 20 observation observation_agent_4 -1.027456 -0.549186
def get_agent_number(row):
    """Return the integer after the last underscore of ``row["agent_id"]``.

    Agent ids are built as ``'<prefix>_agent_<i>'`` by ``get_rows``, so this
    recovers the agent index ``i``. Raises ValueError if that suffix is not
    an integer.
    """
    _, _, index_text = row["agent_id"].rpartition("_")
    return int(index_text)
# Sanity-check get_agent_number on a hand-built row; this should print 0.
test_row = dict(time=20,
                model_type="observation",
                agent_id="observation_agent_0",
                x=15,
                y=25)
print(get_agent_number(test_row))
0
# Integer agent index parsed from each agent_id (e.g. "..._agent_3" -> 3).
results["agent_number"] = results.apply(get_agent_number, axis=1)

# Constant marker-size column. Any existing "s" column is dropped first so
# that re-running this cell re-appends "s" as the last column.
if "s" in results.columns:
    results.drop(columns=["s"], inplace=True)
results["s"] = 10
results.head()
time model_type agent_id x y agent_number s
0 20 observation observation_agent_0 179.201482 679.897110 0 10
1 20 observation observation_agent_1 580.208471 691.454196 1 10
2 20 observation observation_agent_2 0.555967 -2.110513 2 10
3 20 observation observation_agent_3 0.550657 -0.114622 3 10
4 20 observation observation_agent_4 -1.027456 -0.549186 4 10
# Persist the tidy table, then reload it so the plot below is built from
# exactly what was written to disk.
animation_csv = "./animation_results.csv"
results.to_csv(animation_csv, index=False)
results = pd.read_csv(animation_csv)

Create animated scatter plot

# Centre and diameter of the "clock", drawn below as a filled black circle,
# in model coordinates.
clock_x, clock_y = 370, 275
clock_size = 56
half_clock = clock_size / 2

# Bounding box of the circle, as plotly's "circle" shape expects.
x_l, x_h = clock_x - half_clock, clock_x + half_clock
y_l, y_h = clock_y - half_clock, clock_y + half_clock

# One animation frame per time step. animation_group gives each agent_id
# object constancy across frames. The axes are fixed to the model extent
# so the view does not rescale between frames.
f = px.scatter(
    results,
    x='x',
    y='y',
    color='model_type',
    hover_name='agent_id',
    animation_frame='time',
    animation_group='agent_id',
    range_x=[0, model_width],
    range_y=[0, model_height],
    width=1.25 * model_width,
    height=model_height,
)

f.add_shape(
    type="circle",
    xref="x",
    yref="y",
    x0=x_l,
    y0=y_l,
    x1=x_h,
    y1=y_h,
    fillcolor="black",
    line_color="black",
)

f
f.write_html("./enkf_animation.html")